{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# python notebook for Make Your Own Neural Network\n",
    "# code for a 3-layer neural network, and code for learning the MNIST dataset\n",
    "# this version trains using the MNIST dataset, then tests on our own images\n",
    "# (c) Tariq Rashid, 2016\n",
    "# license is GPLv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    "# scipy.special for the sigmoid function expit()\n",
    "import scipy.special\n",
    "# library for plotting arrays\n",
    "import matplotlib.pyplot\n",
    "# ensure the plots are inside this notebook, not an external window\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# helper to load data from PNG image files\n",
    "import imageio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# neural network class definition\n",
    "class neuralNetwork:\n",
    "    \n",
    "    \n",
    "    # initialise the neural network\n",
    "    def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):\n",
    "        # set number of nodes in each input, hidden, output layer\n",
    "        self.inodes = inputnodes\n",
    "        self.hnodes = hiddennodes\n",
    "        self.onodes = outputnodes\n",
    "        \n",
    "        # link weight matrices, wih and who\n",
    "        # weights inside the arrays are w_i_j, where link is from node i to node j in the next layer\n",
    "        # w11 w21\n",
    "        # w12 w22 etc \n",
    "        self.wih = numpy.random.normal(0.0, pow(self.inodes, -0.5), (self.hnodes, self.inodes))\n",
    "        self.who = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.onodes, self.hnodes))\n",
    "\n",
    "        # learning rate\n",
    "        self.lr = learningrate\n",
    "        \n",
    "        # activation function is the sigmoid function\n",
    "        self.activation_function = lambda x: scipy.special.expit(x)\n",
    "        \n",
    "        pass\n",
    "\n",
    "    \n",
    "    # train the neural network\n",
    "    def train(self, inputs_list, targets_list):\n",
    "        # convert inputs list to 2d array\n",
    "        inputs = numpy.array(inputs_list, ndmin=2).T\n",
    "        targets = numpy.array(targets_list, ndmin=2).T\n",
    "        \n",
    "        # calculate signals into hidden layer\n",
    "        hidden_inputs = numpy.dot(self.wih, inputs)\n",
    "        # calculate the signals emerging from hidden layer\n",
    "        hidden_outputs = self.activation_function(hidden_inputs)\n",
    "        \n",
    "        # calculate signals into final output layer\n",
    "        final_inputs = numpy.dot(self.who, hidden_outputs)\n",
    "        # calculate the signals emerging from final output layer\n",
    "        final_outputs = self.activation_function(final_inputs)\n",
    "        \n",
    "        # output layer error is the (target - actual)\n",
    "        output_errors = targets - final_outputs\n",
    "        # hidden layer error is the output_errors, split by weights, recombined at hidden nodes\n",
    "        hidden_errors = numpy.dot(self.who.T, output_errors) \n",
    "        \n",
    "        # update the weights for the links between the hidden and output layers\n",
    "        self.who += self.lr * numpy.dot((output_errors * final_outputs * (1.0 - final_outputs)), numpy.transpose(hidden_outputs))\n",
    "        \n",
    "        # update the weights for the links between the input and hidden layers\n",
    "        self.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), numpy.transpose(inputs))\n",
    "        \n",
    "        pass\n",
    "\n",
    "    \n",
    "    # query the neural network\n",
    "    def query(self, inputs_list):\n",
    "        # convert inputs list to 2d array\n",
    "        inputs = numpy.array(inputs_list, ndmin=2).T\n",
    "        \n",
    "        # calculate signals into hidden layer\n",
    "        hidden_inputs = numpy.dot(self.wih, inputs)\n",
    "        # calculate the signals emerging from hidden layer\n",
    "        hidden_outputs = self.activation_function(hidden_inputs)\n",
    "        \n",
    "        # calculate signals into final output layer\n",
    "        final_inputs = numpy.dot(self.who, hidden_outputs)\n",
    "        # calculate the signals emerging from final output layer\n",
    "        final_outputs = self.activation_function(final_inputs)\n",
    "        \n",
    "        return final_outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# layer sizes for the input, hidden and output layers\n",
    "input_nodes, hidden_nodes, output_nodes = 784, 200, 10\n",
    "\n",
    "# factor by which each weight update is scaled during training\n",
    "learning_rate = 0.1\n",
    "\n",
    "# build the network that the following cells train and query\n",
    "n = neuralNetwork(\n",
    "    inputnodes=input_nodes,\n",
    "    hiddennodes=hidden_nodes,\n",
    "    outputnodes=output_nodes,\n",
    "    learningrate=learning_rate,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the mnist training data CSV file into a list (one string per line of the file)\n",
    "# the with-block closes the file even if reading raises an exception\n",
    "with open(\"mnist_dataset/mnist_train.csv\", 'r') as training_data_file:\n",
    "    training_data_list = training_data_file.readlines()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train the neural network\n",
    "\n",
    "# epochs is the number of times the training data set is used for training\n",
    "epochs = 10\n",
    "\n",
    "for e in range(epochs):\n",
    "    # go through all records in the training data set\n",
    "    for record in training_data_list:\n",
    "        # skip blank lines (for example an empty line at the end of the CSV file),\n",
    "        # which would otherwise fail when the label is converted with int()\n",
    "        if not record.strip():\n",
    "            continue\n",
    "        # split the record by the ',' commas\n",
    "        all_values = record.split(',')\n",
    "        # scale and shift the inputs: a value of 0 maps to 0.01 and a value of 255 maps to 1.0\n",
    "        # (numpy.asfarray was removed in NumPy 2.0; asarray with dtype=float does the same conversion)\n",
    "        inputs = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01\n",
    "        # create the target output values (all 0.01, except the desired label which is 0.99)\n",
    "        targets = numpy.zeros(output_nodes) + 0.01\n",
    "        # all_values[0] is the target label for this record\n",
    "        targets[int(all_values[0])] = 0.99\n",
    "        n.train(inputs, targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test the trained network with our own image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading ... my_own_images/2828_my_own_image.png\n",
      "min =  0.01\n",
      "max =  1.0\n",
      "[[0.03762286]\n",
      " [0.01202177]\n",
      " [0.00885076]\n",
      " [0.90888382]\n",
      " [0.04628071]\n",
      " [0.02848814]\n",
      " [0.00864396]\n",
      " [0.11570917]\n",
      " [0.17224781]\n",
      " [0.02856818]]\n",
      "network says  3\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAD1RJREFUeJzt3V+MVGWax/Hfs8gEUYgSWyQCMhoxiomyqRgjq3EdFGclQVAMXhhIxunRYNxJvFgDF6OJEiTCrBcbIyoZSBhGjeOIou4YsopjBmJhzMAMuuOfZmD514aJw1xI8+fZiz7uttjnraLqVJ3qfr6fhFTVec7b9Vjx16eq3zrnNXcXgHj+oewGAJSD8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCOqMdj7Zeeed51OmTGnnUwKh9PT06Msvv7R69m0q/GZ2q6SnJI2Q9Jy7L0/tP2XKFFWr1WaeEkBCpVKpe9+G3/ab2QhJ/yHph5KukHS3mV3R6M8D0F7NfOa/RtKn7v65u/dJ+pWkOcW0BaDVmgn/hZL2DHi8N9v2LWbWbWZVM6v29vY28XQAitRM+Af7o8J3zg9299XuXnH3SldXVxNPB6BIzYR/r6RJAx5PlLSvuXYAtEsz4f9A0qVm9n0z+56kBZI2FtMWgFZreKrP3Y+b2QOS/lP9U31r3P2PhXUGoKWamud39zckvVFQLwDaiK/3AkERfiAowg8ERfiBoAg/EBThB4Jq6/n8GNyBAweS9WeffTZZ37lzZ25t//79ybFHjx5N1s3Sp4bXqqdWhLrsssuSY2+55ZZkfe7cucn66NGjk/XoOPIDQRF+ICjCDwRF+IGgCD8QFOEHgmKqrw22bduWrN9www3Jel9fX5HtdIxar8u6deuS9TFjxiTrb731Vm7tuuuuS46NgCM/EBThB4Ii/EBQhB8IivADQRF+ICjCDwTFPH8Bap2Se/311yfrx48fT9ZXrFiRrC9YsCC3Nnbs2OTYESNGJOu1TtmtJXVK7+7du5Nja83zr1y5MlmfOXNmbm3v3r3JsePGjUvWhwOO/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QVFPz/GbWI+mIpBOSjrt7pYimOtHJkydza/fff39y7LFjx5L1Wpfmvvfee5P1oWratGnJ+hNPPJGsX3TRRcn64sWLc2sPPfRQcuyaNWuS9Wa//9AJiviSzz+7+5cF/BwAbcTbfiCoZsPvkn5rZtvNrLuIhgC0R7Nv+2e4+z4zO1/S22b2sbtvGbhD9kuhW5ImT57c5NMBKEpTR35335fdHpL0iqRrBtlntbtX3L3S1dXVzNMBKFDD4Tezs8xszDf3Jd0iKX/FSAAdpZm3/eMlvZJNeZwh6Zfunn+tZAAdxVLnWxetUql4tVpt2/MVaenSpbm1ZcuWJcdee+21yfp7772XrJ9xBpddGEyt709ccsklubU9e/Ykx37yySfJ+tSpU5P1slQqFVWr1bq+hMBUHxAU4QeCIvxAUIQfCIrwA0ERfiAo5pDq9M477+TWxo8fnxy7adOmZJ2pvMaMHDkyWV+1alVubf78+cmxtS4b/thjjyXrQwFHfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IignmOm3evDm3lrqstySNHj266HZQh2YuG9fb21tgJ52JIz8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBMU8f51GjRpVdgs4xddff52sN7O0+axZsxoeO1Rw5AeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoGrO85vZGkmzJR1y9yuzbeMkvSBpiqQeSXe5+19b1yYi6uvrS9bnzZuXrO/YsSO3tmjRouTYuXPnJuvDQT1H/l9IuvWUbQ9L2uzul0ranD0GMITUDL+7b5F0+JTNcyStze6vlXR7wX0BaLFGP/OPd/f9kpTdnl9cSwDaoeV/8DOzbjOrmlk1wnXRgKGi
0fAfNLMJkpTdHsrb0d1Xu3vF3StdXV0NPh2AojUa/o2SFmb3F0p6tZh2ALRLzfCb2QZJv5d0mZntNbMfSVou6WYz+7Okm7PHAIaQmvP87n53TukHBfeCIajWmgVHjx7Nre3cuTM59sEHH0zWt27dmqzPnDkzt/b0008nx5pZsj4c8A0/ICjCDwRF+IGgCD8QFOEHgiL8QFBcunsYOHbsWG7t5ZdfTo7dsGFDsv7FF18k6/v27UvWDx8+9Zyw/+fuybG13HHHHcn6unXrcmtcip0jPxAW4QeCIvxAUIQfCIrwA0ERfiAowg8ExTz/MPD444/n1h599NE2dlKsc845J1m/6aabkvXUEt6jR49uqKfhhCM/EBThB4Ii/EBQhB8IivADQRF+ICjCDwTFPP8wMGfOnNzarl27kmObvUR1M+fkb9myJVk/cOBAsr548eJkfenSpbm17du3J8defPHFyfpwwJEfCIrwA0ERfiAowg8ERfiBoAg/EBThB4KqOc9vZmskzZZ0yN2vzLY9IunHknqz3Za4+xutahJp06dPz6298MILbezk9Jw4cSJZ/+yzz5L1559/PllfsWJFbq1SqSTHfv7558l6rWsNDAX1HPl/IenWQbb/3N2vzv4RfGCIqRl+d98iKX/ZFQBDUjOf+R8wsz+Y2RozO7ewjgC0RaPhf1rSJZKulrRf0sq8Hc2s28yqZlbt7e3N2w1AmzUUfnc/6O4n3P2kpGclXZPYd7W7V9y90tXV1WifAArWUPjNbMKAh3Ml7SymHQDtUs9U3wZJN0o6z8z2SvqZpBvN7GpJLqlH0k9a2COAFqgZfne/e5DN6QlWoA4jRoxI1qdOnZqsL1++PFk/fvx4bm3VqlXJsa+99lqyfs899yTrQwHf8AOCIvxAUIQfCIrwA0ERfiAowg8ExaW7MWTVuuz47Nmzc2u1pvp6enoaaWlI4cgPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0Exz49hq9YpwyknT54ssJPOxJEfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Jinh/D1vvvv9/w2EmTJhXYSWfiyA8ERfiBoAg/EBThB4Ii/EBQhB8IivADQdWc5zezSZLWSbpA0klJq939KTMbJ+kFSVMk9Ui6y93/2rpWgW/r6+tL1p988sncWq1z/e+8886GehpK6jnyH5f0kLtfLulaSYvN7ApJD0va7O6XStqcPQYwRNQMv7vvd/cPs/tHJO2SdKGkOZLWZrutlXR7q5oEULzT+sxvZlMkTZe0TdJ4d98v9f+CkHR+0c0BaJ26w29mZ0t6WdJP3f1vpzGu28yqZlbt7e1tpEcALVBX+M1spPqDv97df51tPmhmE7L6BEmHBhvr7qvdveLula6uriJ6BlCAmuG3/qVQn5e0y90HLm26UdLC7P5CSa8W3x6AVqnnlN4Zku6RtMPMPsq2LZG0XNKLZvYjSX+RNL81LbbHsWPHkvWNGzfm1mbNmpUce/bZZzfUU3TunqwvWbIkWT98+HBu7b777kuOHTt2bLI+HNQMv7v/TlLeQug/KLYdAO3CN/yAoAg/EBThB4Ii/EBQhB8IivADQXHp7ky1Wk3WU6d4zp49Ozn2pZdeStZHjRqVrEf14osvJusrV65M1i+44ILc2ooVKxrqaTjhyA8ERfiBoAg/EBThB4Ii/EBQhB8IivADQTHPn5k+fXqyfvnll+fWXn/99eTYefPmJevPPPNMsj5x4sRkvf96K50pdZ2E9evXJ8d2d3cn67X+u999993c2pgxY5JjI+DIDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBMc+fqXVO/fbt23Nrta7b/+abbybrkydPTtbPPPPMZP22227Lrc2YMSM5ttZS1bV89dVXyfpzzz2XW9u9e3dy7MiRI5P1TZs2JetTp05N1qPjyA8ERfiBoAg/EBThB4Ii/EBQhB8IivADQVmtNdDNbJKkdZIukHRS0mp3f8rMHpH0Y0m9
2a5L3P2N1M+qVCpe6/r4Q1FfX1+yvnbt2mS91nX9t27dmqwfOXIkWS9T6pz7RYsWJccuW7YsWU9dlz+qSqWiarVa1wUe6vmSz3FJD7n7h2Y2RtJ2M3s7q/3c3Z9stFEA5akZfnffL2l/dv+Ime2SdGGrGwPQWqf1md/MpkiaLmlbtukBM/uDma0xs3NzxnSbWdXMqr29vYPtAqAEdYffzM6W9LKkn7r73yQ9LekSSVer/53BoAunuftqd6+4e6Wrq6uAlgEUoa7wm9lI9Qd/vbv/WpLc/aC7n3D3k5KelXRN69oEULSa4bf+P9c+L2mXu68asH3CgN3mStpZfHsAWqWeqb5/kvSepB3qn+qTpCWS7lb/W36X1CPpJ9kfB3MN16m+Vjtx4kSy/vHHH+fW9uzZkxzb7GW/a50SfNVVV+XW+BhYvEKn+tz9d5IG+2HJOX0AnY1v+AFBEX4gKMIPBEX4gaAIPxAU4QeC4tLdQ0CtufRp06Y1VENsHPmBoAg/EBThB4Ii/EBQhB8IivADQRF+IKia5/MX+mRmvZIGrst8nqQv29bA6enU3jq1L4neGlVkbxe5e10XSmhr+L/z5GZVd6+U1kBCp/bWqX1J9NaosnrjbT8QFOEHgio7/KtLfv6UTu2tU/uS6K1RpfRW6md+AOUp+8gPoCSlhN/MbjWzT8zsUzN7uIwe8phZj5ntMLOPzKzU64xny6AdMrOdA7aNM7O3zezP2e2gy6SV1NsjZvY/2Wv3kZn9S0m9TTKz/zKzXWb2RzP712x7qa9doq9SXre2v+03sxGS/lvSzZL2SvpA0t3u/qe2NpLDzHokVdy99DlhM7tB0t8lrXP3K7NtKyQddvfl2S/Oc9393zqkt0ck/b3slZuzBWUmDFxZWtLtkhapxNcu0dddKuF1K+PIf42kT939c3fvk/QrSXNK6KPjufsWSYdP2TxH0trs/lr1/8/Tdjm9dQR33+/uH2b3j0j6ZmXpUl+7RF+lKCP8F0oauIzMXnXWkt8u6bdmtt3MustuZhDjv1kZKbs9v+R+TlVz5eZ2OmVl6Y557RpZ8bpoZYR/sNV/OmnKYYa7/6OkH0panL29RX3qWrm5XQZZWbojNLriddHKCP9eSZMGPJ4oaV8JfQzK3fdlt4ckvaLOW3344DeLpGa3h0ru5/900srNg60srQ547Tppxesywv+BpEvN7Ptm9j1JCyRtLKGP7zCzs7I/xMjMzpJ0izpv9eGNkhZm9xdKerXEXr6lU1ZuzltZWiW/dp224nUpX/LJpjL+XdIISWvc/fG2NzEIM7tY/Ud7qf/Kxr8sszcz2yDpRvWf9XVQ0s8k/UbSi5ImS/qLpPnu3vY/vOX0dqNOc+XmFvWWt7L0NpX42hW54nUh/fANPyAmvuEHBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiCo/wWgVWZ/YILUIAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# test the neural network with our own images\n",
    "\n",
    "# path of the image to classify, defined once so the message and the load always agree\n",
    "image_path = 'my_own_images/2828_my_own_image.png'\n",
    "\n",
    "# load image data from png files into an array\n",
    "# NOTE(review): newer imageio releases may no longer accept as_gray - confirm against the installed version\n",
    "print(\"loading ...\", image_path)\n",
    "img_array = imageio.imread(image_path, as_gray=True)\n",
    "\n",
    "# reshape from 28x28 to list of 784 values, invert values\n",
    "img_data = 255.0 - img_array.reshape(784)\n",
    "\n",
    "# then scale data to range from 0.01 to 1.0\n",
    "img_data = (img_data / 255.0 * 0.99) + 0.01\n",
    "print(\"min = \", numpy.min(img_data))\n",
    "print(\"max = \", numpy.max(img_data))\n",
    "\n",
    "# plot image\n",
    "matplotlib.pyplot.imshow(img_data.reshape(28,28), cmap='Greys', interpolation='None')\n",
    "\n",
    "# query the network\n",
    "outputs = n.query(img_data)\n",
    "print(outputs)\n",
    "\n",
    "# the index of the highest value corresponds to the label\n",
    "label = numpy.argmax(outputs)\n",
    "print(\"network says \", label)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
